from gensim.models import Word2Vec
import pdb
import pickle as pkl
import numpy as np
import re
from tokenizer import tokenize


def get_embeds(embed_dim=32):
    """Train Word2Vec embeddings on the training corpus and save the results.

    Reads ``idlist_train_n2.txt`` from the hard-coded data directory. Each
    line is split on tabs; the second field is tokenized, or the first field
    when the line contains no tab. Only sentences with more than two tokens
    are kept for training.

    Files written to the same directory:

    * ``word_index.pkl`` -- dict mapping each vocabulary word to a 1-based
      row index (row 0 is reserved for an all-zero vector).
    * ``word_embed.bin`` -- the trained Word2Vec model.
    * ``embedding_matrix.npy`` -- raw matrix, row 0 all zeros.
    * ``embedding_matrix_norm.npy`` -- the matrix min-max scaled using the
      global min/max of the raw matrix, with row 0 reset to zeros.

    Args:
        embed_dim: Dimensionality of the word vectors. Defaults to 32.

    Raises:
        ValueError: If no sentence with more than two tokens is found.
    """
    data_dir = 'D:\\Work\\kusuri\\data_pi\\'
    corpus_path = data_dir + 'idlist_train_n2.txt'

    texts = []
    with open(corpus_path, 'r', encoding='utf-8') as fin:
        for line in fin:
            fields = line.strip().split('\t')
            # Use the second column when present, otherwise the only column.
            raw_text = fields[0] if len(fields) < 2 else fields[1]
            text_words = tokenize(raw_text)
            if len(text_words) > 2:
                texts.append(text_words)

    if not texts:
        # Fail with a clear message instead of an opaque max()/division error.
        raise ValueError(
            'No sentences with more than 2 tokens found in ' + corpus_path)

    lengths = [len(text) for text in texts]
    print('Max sentence length:', max(lengths))
    print('Avg sentence length:', sum(lengths) / len(lengths))
    print('Min sentence length:', min(lengths))

    # NOTE(review): `size=` and `wv.vocab` are the gensim 3.x names; gensim 4
    # renamed them to `vector_size=` and `wv.key_to_index`.
    model = Word2Vec(texts, min_count=1, size=embed_dim)
    words = list(model.wv.vocab)

    # Index 0 is reserved for the all-zero row, so real words start at 1.
    word_index = {word: idx for idx, word in enumerate(words, start=1)}
    rows = [np.zeros(embed_dim)]
    # Look vectors up through `model.wv`; indexing the model itself is
    # deprecated.
    rows.extend(model.wv[word] for word in words)
    embedding_matrix = np.array(rows)

    print("Vocab Size:", len(word_index))
    # Use a context manager so the pickle file handle is always closed.
    with open(data_dir + 'word_index.pkl', 'wb') as fout:
        pkl.dump(word_index, fout)
    model.save(data_dir + 'word_embed.bin')
    np.save(data_dir + 'embedding_matrix', embedding_matrix)

    # Min-max scale with the global extrema of the raw matrix.
    max_val = np.max(embedding_matrix)
    min_val = np.min(embedding_matrix)
    embedding_matrix = (embedding_matrix - min_val) / (max_val - min_val)
    # Scaling shifts row 0 away from zero; reset it so it stays all zeros.
    embedding_matrix[0] = 0

    np.save(data_dir + 'embedding_matrix_norm', embedding_matrix)


# Run the embedding pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    get_embeds()